#!/usr/bin/python
#
# Copyright (c) 2011 Vince Durham
# Distributed under the MIT/X11 software license, see the accompanying
# file COPYING.
#

import logging
import argparse
import os
import sys
import traceback
import json
import base64
import socket

from datetime import datetime

from twisted.internet import defer, reactor, threads
from twisted.web import server, resource
from twisted.internet.error import ConnectionRefusedError
import twisted.internet.error
from urlparse import urlsplit
import httplib
import thread

# Version string reported by --version and embedded in the argparse description.
__version__ = '0.2.2'

'''
merge-mine-proxy

Run behind a pool or a miner to mine a parent chain and a set of auxiliary chains.

Output is in the form:

2011-07-07T00:00:00,solve,1,1,HASH

Where the fields are:

    * UTC date and time in ISO format
    * The word "solve"
    * 1 if the proof of work was accepted by the parent chain, otherwise 0
    * One field per aux chain: 1 if the proof of work was accepted by that aux chain, otherwise 0
    * HASH parent block hash

'''

# Delay handed to reactor.callLater between two polls of the aux chains.
AUX_UPDATE_INTERVAL = 5
# Maximum number of merkle trees remembered; the oldest is dropped beyond this.
MERKLE_TREES_TO_KEEP = 24

# Module-wide logger; handlers are attached in run().
logger = logging.getLogger('merged-mine-proxy')
logger.setLevel(logging.DEBUG)

def reverse_chunks(s, l):
    """Split ``s`` into pieces of length ``l`` and join them in reverse order.

    Called with ``l == 2`` elsewhere in this file to flip the byte order of a
    hex string.  When ``len(s)`` is not a multiple of ``l`` the short trailing
    piece ends up first in the result.

    Raises ValueError if ``l`` is not positive.
    """
    if l <= 0:
        # A negative length used to produce an empty string without complaint.
        raise ValueError('chunk length must be positive')
    # range() rather than xrange() keeps the helper usable on Python 3 as well.
    return ''.join(reversed([s[x:x + l] for x in range(0, len(s), l)]))

def getresponse(http, path, postdata, headers):
    """POST ``postdata`` to ``path`` over ``http`` and return the response body.

    http     -- connection object with httplib.HTTPConnection's interface
    path     -- request path on the server
    postdata -- request body
    headers  -- dict of request headers
    """
    # HTTPConnection.request() expects (method, url, body, headers); the verb
    # and the path were previously passed in the wrong order.
    http.request('POST', path, postdata, headers)
    return http.getresponse().read()

class Error(Exception):
    """JSON-RPC error made of a numeric code, a unicode message and extra data."""

    def __init__(self, code, message, data=''):
        """Store the three error fields after type-checking code and message.

        Raises TypeError if ``code`` is not an int or ``message`` is not unicode.
        """
        if not isinstance(code, int):
            raise TypeError('code must be an int')
        if not isinstance(message, unicode):
            raise TypeError('message must be a unicode')
        self._code = code
        self._message = message
        self._data = data

    def __str__(self):
        """Render as ``<code> <message> <repr of data>``."""
        fields = (self._code, self._message, self._data)
        return '%i %s %r' % fields

    def _to_obj(self):
        """Return the dict used as the ``error`` member of a JSON-RPC reply."""
        return dict(code=self._code, message=self._message, data=self._data)

class Proxy(object):
    """Small blocking JSON-RPC 2.0 client bound to one backend URL.

    Any attribute named ``rpc_<method>`` resolves to a callable that invokes
    ``<method>`` on the backend through callRemote().  One HTTP connection is
    kept open and reused between calls.
    """

    def __init__(self, url):
        """Split ``url`` (``scheme://[user:pass@]host[:port][/path]``) into parts."""
        (schema, netloc, path, query, fragment) = urlsplit(url)
        auth = None
        if '@' in netloc:
            # Split on the LAST '@' so a password containing '@' stays intact.
            (auth, netloc) = netloc.rsplit('@', 1)
        if path == "":
            path = "/"
        # Built without the credentials; this is the form that gets logged.
        self._url = "%s://%s%s" % (schema, netloc, path)
        self._path = path
        self._auth = auth
        self._netloc = netloc
        self._http = None

    def _drop_connection(self):
        """Close and forget the cached connection so the next call reconnects."""
        http, self._http = self._http, None
        if http is not None:
            try:
                http.close()
            except socket.error:
                # The connection is being discarded anyway.
                pass

    def callRemote(self, method, *params):
        """Call ``method`` with ``params`` on the backend and return its result.

        Raises Error(-32099) when the backend cannot be reached or the HTTP
        exchange fails, Error with the backend's code and message when the
        reply carries an error, and ValueError when the reply id does not
        match the request id.
        """
        id_ = 0
        try:
            if self._http is None:
                # HTTPConnection parses "host[:port]" itself, which also covers
                # URLs without a port and bracketed IPv6 literals.
                self._http = httplib.HTTPConnection(self._netloc)
                self._http.connect()

            headers = {
                'Content-Type': 'text/json',
            }
            if self._auth is not None:
                headers['Authorization'] = 'Basic ' + base64.b64encode(self._auth)

            postdata = json.dumps({
                'jsonrpc': '2.0',
                'method': method,
                'params': params,
                'id': id_,
            })

            content = getresponse(self._http, self._path, postdata, headers)
        except (httplib.HTTPException, socket.error):
            # Socket-level failures while sending or reading are treated like
            # a failed connect: drop the connection and report the backend as
            # unreachable.
            self._drop_connection()
            logger.error("Could not connect to %s", self._url)
            raise Error(-32099, u'Could not connect to backend', self._url)

        resp = json.loads(content)

        if resp['id'] != id_:
            raise ValueError('invalid id')
        if 'error' in resp and resp['error'] is not None:
            raise Error(resp['error']['code'], resp['error']['message'])
        return resp['result']

    def __getattr__(self, attr):
        """Turn ``rpc_<name>`` lookups into remote calls of ``<name>``."""
        if attr.startswith('rpc_'):
            return lambda *params: self.callRemote(attr[len('rpc_'):], *params)
        raise AttributeError('%r object has no attribute %r' % (self.__class__.__name__, attr))

class Server(resource.Resource):
    """Twisted web resource that answers JSON-RPC 2.0 requests sent by POST.

    Remote procedures are the ``rpc_<name>`` attributes of the instance; they
    may return a plain value or a Deferred.
    """
    # Optional dict of extra response headers set on every POST reply.
    extra_headers = None

    def render(self, request):
        """Run the normal Resource dispatch and finish the request afterwards.

        The handler may return a Deferred, so the request is completed from
        callbacks and NOT_DONE_YET is returned to Twisted.
        """
        def finish(x):
            if request._disconnected:
                return
            if x is not None:
                request.write(x)
            request.finish()

        def finish_error(fail):
            if request._disconnected:
                return
            request.setResponseCode(500) # won't do anything if already written to
            request.write('---ERROR---')
            request.finish()
            fail.printTraceback()

        defer.maybeDeferred(resource.Resource.render, self, request).addCallbacks(finish, finish_error)
        return server.NOT_DONE_YET

    @defer.inlineCallbacks
    def render_POST(self, request):
        """Decode one JSON-RPC request, invoke ``rpc_<method>`` and write the reply.

        Error replies use code -32700 for a body that is not JSON, -32600 for
        a malformed request, -32601 for an unknown method and -32099 for any
        exception raised by the method.  A request without an id gets no
        reply body.
        """
        # missing batching, 1.0 notifications
        data = request.content.read()

        if self.extra_headers is not None:
            for name, value in self.extra_headers.items():
                request.setHeader(name, value)

        # Stays None (id unknown) until the request has been parsed.
        id_ = None

        try:
            try:
                req = json.loads(data)
            except Exception:
                # This used to raise an undefined name, so the parse-error
                # reply below was never produced.
                raise Error(-32700, u'Parse error')

            try:
                # A body that is valid JSON but not an object fails here and
                # is reported as an invalid request.
                id_ = req.get('id', None)
                method = req['method']
                if not isinstance(method, unicode):
                    raise ValueError()
                params = req.get('params', [])
                if not isinstance(params, list):
                    raise ValueError()
            except Exception:
                raise Error(-32600, u'Invalid Request')

            method_name = 'rpc_' + method
            if not hasattr(self, method_name):
                raise Error(-32601, u'Method not found')
            method_meth = getattr(self, method_name)

            df = defer.maybeDeferred(method_meth, *params)

            if id_ is None:
                # No id: the call has been started but nothing is sent back.
                return

            try:
                result = yield df
            except Exception as e:
                # Error instances raised by the method are wrapped as well;
                # passing them through unchanged was disabled in this code.
                logger.error(str(e))
                raise Error(-32099, u'Unknown error: ' + str(e))

            res = json.dumps({
                'jsonrpc': '2.0',
                'id': id_,
                'result': result,
                'error': None,
            })
            request.setHeader('content-length', str(len(res)))
            request.write(res)
        except Error as e:
            res = json.dumps({
                'jsonrpc': '2.0',
                'id': id_,
                'result': None,
                'error': e._to_obj(),
            })
            request.setHeader('content-length', str(len(res)))
            request.write(res)

class Listener(Server):
    """JSON-RPC front end that hands merged-mining work to miners.

    Work is requested from the ``parent`` proxy.  Every proxy in ``auxs``
    contributes an aux block hash; those hashes are placed in a merkle tree
    built by the parent, and the tree root is embedded in the work handed
    out.  Exposes the ``getaux`` and ``getwork`` RPC methods.
    """

    def __init__(self, parent, auxs, merkle_size, rewrite_target):
        """Set up state for one parent and a list of aux chains.

        parent         -- Proxy for the parent chain daemon
        auxs           -- list of Proxy objects, one per aux chain
        merkle_size    -- number of leaves in the aux merkle tree (at most 255)
        rewrite_target -- 1 or 100 to hand out a fixed target instead of the
                          real one; any other value leaves targets untouched

        Raises ValueError if ``merkle_size`` exceeds 255.
        """
        Server.__init__(self)
        self.parent = parent
        self.auxs = auxs
        # Per-chain values, filled in by update_auxs().
        self.chain_ids = [None for i in auxs]
        self.aux_targets = [None for i in auxs]
        self.merkle_size = merkle_size
        # Roots in insertion order (oldest first) and root -> tree lookup.
        self.merkle_tree_queue = []
        self.merkle_trees = {}
        self.rewrite_target = None
        if rewrite_target == 1:
            self.rewrite_target = reverse_chunks("00000000fffffffffffffffffffffffffffffffffffffffffffffffffffffffe", 2)
        elif rewrite_target == 100:
            self.rewrite_target = reverse_chunks("00000000028f5c28000000000000000000000000000000000000000000000000", 2)
        if merkle_size > 255:
            raise ValueError('merkle size up to 255')
        self.putChild('', self)

    def merkle_branch(self, chain_index, merkle_tree):
        """Return the sibling hashes that link leaf ``chain_index`` to the root.

        NOTE(review): the indexing assumes ``merkle_tree`` is a flat list
        holding the leaves first and then each higher level in turn --
        confirm against the parent's buildmerkletree output.
        """
        step = self.merkle_size
        i1 = chain_index
        j = 0
        branch = []
        while step > 1:
            i = min(i1 ^ 1, step - 1)
            branch.append(merkle_tree[i + j])
            i1 = i1 >> 1
            j += step
            # Explicit floor division; identical to "/" on Python 2 ints.
            step = (step + 1) // 2
        return branch

    def calc_merkle_index(self, chain):
        """Return the merkle leaf slot used for aux chain number ``chain``.

        The slot is derived from the chain id (as last seen by update_auxs)
        with a fixed nonce of 0.
        NOTE(review): two chains can map to the same slot; nothing here
        detects that -- verify chain ids / merkle size when adding chains.
        """
        chain_id = self.chain_ids[chain]
        rand = 0 # nonce
        rand = (rand * 1103515245 + 12345) & 0xffffffff
        rand += chain_id
        rand = (rand * 1103515245 + 12345) & 0xffffffff
        return rand % self.merkle_size

    @defer.inlineCallbacks
    def update_auxs(self):
        """Fetch a block from every aux chain and remember the resulting tree.

        Updates ``chain_ids`` and ``aux_targets`` and, when the merkle root is
        new, stores the tree, evicting the oldest one once more than
        MERKLE_TREES_TO_KEEP are held.
        """
        # Placeholder leaves with an arbitrary value: 62 zeros plus two hex
        # digits is 64 hex characters.  The padding used to be 63 zeros,
        # which produced 65-character strings.
        merkle_leaves = [('0' * 62) + ("%02x" % x) for x in range(self.merkle_size)]

        # ask each aux chain for a block
        for chain, aux in enumerate(self.auxs):
            aux_block = (yield aux.rpc_getauxblock())
            aux_block_hash = aux_block['hash']
            self.chain_ids[chain] = aux_block['chainid']
            chain_merkle_index = self.calc_merkle_index(chain)
            merkle_leaves[chain_merkle_index] = aux_block_hash
            self.aux_targets[chain] = reverse_chunks(aux_block['target'], 2) # fix endian

        # create merkle tree
        merkle_tree = (yield self.parent.rpc_buildmerkletree(*merkle_leaves))
        merkle_root = merkle_tree[-1]

        if merkle_root not in self.merkle_trees:
            # remember new tree
            self.merkle_trees[merkle_root] = merkle_tree
            self.merkle_tree_queue.append(merkle_root)
            if len(self.merkle_tree_queue) > MERKLE_TREES_TO_KEEP:
                # forget one tree
                old_root = self.merkle_tree_queue.pop(0)
                del self.merkle_trees[old_root]

    def _log_aux_update_failure(self, fail):
        """Errback for update_auxs(): log the failure instead of leaving it unhandled."""
        logger.error("aux update failed: %s", fail.getErrorMessage())

    def update_aux_process(self):
        """Refresh the aux data now and schedule the next refresh."""
        # Re-arm the timer first so a failing update cannot stop the polling.
        reactor.callLater(AUX_UPDATE_INTERVAL, self.update_aux_process)
        self.update_auxs().addErrback(self._log_aux_update_failure)

    def rpc_getaux(self, data=None):
        ''' Use this rpc call to get the aux chain merkle root and aux target.  Pool software
        can then call getworkaux(aux) instead of going through this proxy.  It is enough to call this
        once a second.

        Returns a dict with the keys 'aux' and 'aux_target'.
        '''
        try:
            # Get aux based on the latest tree
            merkle_root = self.merkle_tree_queue[-1]
            # nonce = 0, one byte merkle size
            aux = merkle_root + ("%02x000000" % self.merkle_size) + "00000000"
            result = {'aux': aux}

            if self.rewrite_target:
                result['aux_target'] = self.rewrite_target
            else:
                # Find highest target
                targets = sorted(self.aux_targets)
                result['aux_target'] = reverse_chunks(targets[-1], 2) # fix endian
            return result
        except Exception:
            logger.error(traceback.format_exc())
            raise

    @defer.inlineCallbacks
    def rpc_getwork(self, data=None):
        ''' This rpc call generates bitcoin miner compatible work from data received
        from the aux and parent chains.

        Without ``data`` it returns new work from the parent.  With ``data``
        it submits the solution to the aux chains whose target it meets and
        to the parent, and returns a true value if any of them accepted it
        (False for stale work or an unknown merkle root).
        '''
        if data:
            # Submit work upstream
            any_solved = False
            aux_solved = []

            # get merkle root
            solution = (yield self.parent.rpc_getworkaux("", data))

            if solution is False:
                logger.error("stale work")
                defer.returnValue(False)

            parent_hash = solution['hash']
            merkle_root = solution['aux'][:-16] # strip off size and nonce
            if merkle_root not in self.merkle_trees:
                logger.error("stale merkle root %s", merkle_root)
                defer.returnValue(False)

            merkle_tree = self.merkle_trees[merkle_root]

            # submit to each aux chain
            for chain, aux in enumerate(self.auxs):
                chain_merkle_index = self.calc_merkle_index(chain)
                aux_solved.append(False)
                # try submitting if under target (compared as hex strings)
                if self.aux_targets[chain] > parent_hash and chain_merkle_index is not None:
                    branch = self.merkle_branch(chain_merkle_index, merkle_tree)
                    proof = (
                        yield self.parent.rpc_getworkaux("", data, chain_merkle_index, *branch))
                    if proof is False:
                        logger.error("aux pow request rejected by parent, chain %d", chain)
                    else:
                        aux_hash = merkle_tree[chain_merkle_index]
                        aux_solved[-1] = (
                            yield aux.rpc_getauxblock(aux_hash, proof['auxpow']))
                        any_solved = any_solved or aux_solved[-1]

            # submit to parent
            parent_solved = (yield self.parent.rpc_getworkaux("submit", data))
            any_solved = any_solved or parent_solved

            logger.info("%s,solve,%s,%s,%s", datetime.utcnow().isoformat(),
                                        "1" if parent_solved else "0",
                                        ",".join(["1" if solve else "0" for solve in aux_solved]),
                                        parent_hash)
            defer.returnValue(any_solved)
        else:
            # Get work based on the latest tree
            merkle_root = self.merkle_tree_queue[-1]
            # nonce = 0, one byte merkle size
            aux = merkle_root + ("%02x000000" % self.merkle_size) + "00000000"
            work = (yield self.parent.rpc_getworkaux(aux))
            if self.rewrite_target:
                work['target'] = self.rewrite_target
            else:
                # Find highest target
                targets = [reverse_chunks(work['target'], 2)] # fix endian
                targets.extend(self.aux_targets)
                targets.sort()
                work['target'] = reverse_chunks(targets[-1], 2) # fix endian
            defer.returnValue(work)

def main(args):
    """Create the proxies and the listener and start serving miners.

    ``args`` is the namespace produced by the argument parser in run().
    Raises ValueError when the merkle size cannot hold all aux chains.
    """
    parent = Proxy(args.parent_url)
    aux_urls = args.aux_urls or ['http://un:pw@127.0.0.1:8342/']
    auxs = [Proxy(url) for url in aux_urls]
    if args.merkle_size is None:
        for i in range(8):
            if (1 << i) > len(aux_urls):
                # Store the power of two itself; the exponent used to be
                # stored, which made three or more aux chains fail the
                # size check below.
                args.merkle_size = 1 << i
                logger.info('merkle size = %d', args.merkle_size)
                break

    # merkle_size is still None when no default fitted (128 or more chains).
    if args.merkle_size is None or len(aux_urls) > args.merkle_size:
        raise ValueError('the merkle size must be at least as large as the number of aux chains')

    if args.pidfile:
        # "with" closes the file even if the write fails
        with open(args.pidfile, 'w') as pidfile:
            pidfile.write(str(os.getpid()))

    listener = Listener(parent, auxs, args.merkle_size, args.rewrite_target)
    listener.update_aux_process()
    reactor.listenTCP(args.worker_port, server.Site(listener))

def run():
    """Parse the command line, attach a log handler and run the reactor.

    main() is scheduled to run with the parsed arguments once the reactor
    is running.
    """
    parser = argparse.ArgumentParser(
        description='merge-mine-proxy (version %s)' % (__version__,))
    parser.add_argument('--version', action='version', version=__version__)

    miner_opts = parser.add_argument_group('worker interface')
    miner_opts.add_argument(
        '-w', '--worker-port',
        dest='worker_port', metavar='PORT',
        type=int, action='store', default=9992,
        help='listen on PORT for RPC connections from miners asking for work and providing responses (default: 9992)')

    parent_opts = parser.add_argument_group('parent chain (bitcoin) interface')
    parent_opts.add_argument(
        '-p', '--parent-url',
        dest='parent_url', metavar='PARENT_URL',
        type=str, action='store',
        default='http://un:pw@127.0.0.1:8332/',
        help='connect to the parent RPC at this address (default: http://un:pw@127.0.0.1:8332/)')

    aux_opts = parser.add_argument_group('aux chain (e.g. namecoin) interface(s)')
    aux_opts.add_argument(
        '-x', '--aux-url',
        dest='aux_urls', metavar='AUX_URL',
        type=str, action='store', nargs='+',
        help='connect to the aux RPC at this address (default: http://un:pw@127.0.0.1:8342/)')
    aux_opts.add_argument(
        '-s', '--merkle-size',
        dest='merkle_size', metavar='SIZE',
        type=int, action='store', default=None,
        help='use these many entries in the merkle tree.  Must be a power of 2. Default is lowest power of 2 greater than number of aux chains.')

    # Both switches write to the same destination.
    parser.add_argument(
        '-r', '--rewrite-target',
        dest='rewrite_target', action='store_const', const=1, default=False,
        help='rewrite target difficulty to 1')
    parser.add_argument(
        '-R', '--rewrite-target-100',
        dest='rewrite_target', action='store_const', const=100, default=False,
        help='rewrite target difficulty to 100')

    parser.add_argument('-i', '--pidfile', dest='pidfile', metavar='PID',
                        type=str, action='store', default=None)
    parser.add_argument('-l', '--logfile', dest='logfile', metavar='LOG',
                        type=str, action='store', default=None)

    options = parser.parse_args()

    # Log to the given file, or to the default stream when none was given.
    handler = logging.FileHandler(options.logfile) if options.logfile else logging.StreamHandler()
    logger.addHandler(handler)

    reactor.callWhenRunning(main, options)
    reactor.run()

# Start the proxy only when executed as a script, not when imported.
if __name__ == "__main__":
    run()
